As discussed previously, when we shift optimization parameters from vectors to matrices and adopt the more matrix-appropriate spectral norm constraint, the Muon optimizer emerges naturally. Further, we considered the steepest descent direction under additional orthogonal constraints, dividing the discussion into square and non-square matrices. The square matrix case was resolved in the previous article, but the non-square case remained open.
The goal of this article is to complete the solution for the non-square case, fully resolving optimization under orthogonal constraints.
Problem Formulation#
Let us briefly review the results of the previous article, "Manifold Gradient Descent: 2. Muon with Orthogonal Constraints". The objective to solve is:
$$\mathop{\arg\max}_{\boldsymbol{\Phi}}\tr(\boldsymbol{G}^{\top}\boldsymbol{\Phi})\quad\text{s.t.}\quad\Vert\boldsymbol{\Phi}\Vert_2\leq 1,\quad(\boldsymbol{W}-\eta\boldsymbol{\Phi})^{\top}(\boldsymbol{W}-\eta\boldsymbol{\Phi})=\boldsymbol{I}$$
where $\boldsymbol{W},\boldsymbol{\Phi}\in\mathbb{R}^{n\times m}$ ($n \geq m$), and $\Vert\cdot\Vert_2$ is the spectral norm. Based on the "first-order approximation suffices" principle, this simplifies to:
$$\mathop{\arg\max}_{\boldsymbol{\Phi}}\tr(\boldsymbol{G}^{\top}\boldsymbol{\Phi})\quad\text{s.t.}\quad\Vert\boldsymbol{\Phi}\Vert_2\leq 1,\quad\boldsymbol{W}^{\top}\boldsymbol{\Phi}+\boldsymbol{\Phi}^{\top}\boldsymbol{W}=\boldsymbol{0}$$
The set of all $\boldsymbol{\Phi}$ satisfying $\boldsymbol{W}^{\top}\boldsymbol{\Phi}+\boldsymbol{\Phi}^{\top}\boldsymbol{W} = \boldsymbol{0}$ is called the "tangent space" of $\boldsymbol{W}^{\top}\boldsymbol{W}=\boldsymbol{I}$. In the previous article, we derived the general form of the solution:
$$\boldsymbol{\Phi} = \msign(\boldsymbol{G} + \boldsymbol{W}\boldsymbol{X})$$
where $\boldsymbol{X}\in\mathbb{R}^{m\times m}$ is an undetermined symmetric matrix.
The remaining challenge is to provide a method for computing the symmetric matrix $\boldsymbol{X}$ such that $\boldsymbol{W}^{\top}\boldsymbol{\Phi}$ is skew-symmetric; once it is found, the corresponding $\boldsymbol{\Phi}$ is automatically the optimal solution. For $n=m$ we already obtained the closed-form solution $\boldsymbol{X}=-[\boldsymbol{W}^{\top}\boldsymbol{G}]_{\text{sym}}$; the truly difficult case is $n > m$, where the constraint set is the "Stiefel manifold", and this is the open problem noted in the earlier "Orthogonal manifold" discussion.
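As a quick sanity check (an illustration added here, not from the original article), the square-case closed form can be verified numerically: with $n=m$ and $\boldsymbol{X}=-[\boldsymbol{W}^{\top}\boldsymbol{G}]_{\text{sym}}$, the resulting $\boldsymbol{\Phi}$ should make $\boldsymbol{W}^{\top}\boldsymbol{\Phi}$ skew-symmetric.

```python
import numpy as np

def msign(g):
    # msign via SVD: replace all (nonzero) singular values by 1
    u, s, vh = np.linalg.svd(g, full_matrices=False)
    return u @ vh

def sym(x):
    return (x + x.T) / 2

np.random.seed(0)
n = 6
w = np.linalg.qr(np.random.randn(n, n))[0]   # square orthogonal W
g = np.random.randn(n, n)                    # gradient
x = -sym(w.T @ g)                            # closed-form X for n = m
phi = msign(g + w @ x)
print(np.abs(sym(w.T @ phi)).max())          # ~1e-16: W^T Phi is skew-symmetric
```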
Equation Transformation#
In essence, our current task is to solve the equation system:
$$\boldsymbol{W}^{\top}\msign(\boldsymbol{G} + \boldsymbol{W}\boldsymbol{X}) + \msign(\boldsymbol{G} + \boldsymbol{W}\boldsymbol{X})^{\top}\boldsymbol{W} = \boldsymbol{0},\qquad \boldsymbol{X} = \boldsymbol{X}^{\top}$$
When $n=m$, $\boldsymbol{W}^{\top}$ can be directly absorbed into $\msign$, simplifying the solution. However, for $n > m$, such absorption is not possible, constituting the core difficulty. The author posits that for $n > m$, no simple explicit solution exists, so we seek numerical algorithms.
According to the definition $\msign(\boldsymbol{M})=\boldsymbol{M}(\boldsymbol{M}^{\top}\boldsymbol{M})^{-1/2}$, we can write:
$$\msign(\boldsymbol{G} + \boldsymbol{W}\boldsymbol{X}) = (\boldsymbol{G} + \boldsymbol{W}\boldsymbol{X})\boldsymbol{Q}^{-1}$$
where $\boldsymbol{Q} = ((\boldsymbol{G} + \boldsymbol{W}\boldsymbol{X})^{\top}(\boldsymbol{G} + \boldsymbol{W}\boldsymbol{X}))^{1/2}$. In this new notation, the equation system becomes:
$$\boldsymbol{W}^{\top}(\boldsymbol{G} + \boldsymbol{W}\boldsymbol{X})\boldsymbol{Q}^{-1} + \boldsymbol{Q}^{-1}(\boldsymbol{G} + \boldsymbol{W}\boldsymbol{X})^{\top}\boldsymbol{W} = \boldsymbol{0}$$
Left- and right-multiplying by $\boldsymbol{Q}$ yields:
$$\boldsymbol{Q}\boldsymbol{W}^{\top}(\boldsymbol{G} + \boldsymbol{W}\boldsymbol{X}) + (\boldsymbol{G} + \boldsymbol{W}\boldsymbol{X})^{\top}\boldsymbol{W}\boldsymbol{Q} = \boldsymbol{0}$$
where $\boldsymbol{Q}$ also satisfies:
$$\boldsymbol{Q} = (\boldsymbol{G} + \boldsymbol{W}\boldsymbol{X})^{\top}\msign(\boldsymbol{G} + \boldsymbol{W}\boldsymbol{X})$$
Iterative Solution#
The author's approach is to start from an initial guess $\boldsymbol{X}$, substitute into Equation (5) to obtain $\boldsymbol{Q}$, then substitute $\boldsymbol{Q}$ into equation system (4) to solve for a new $\boldsymbol{X}$, iterating until convergence. Given $\msign$, Equation (5) can be computed explicitly, so the sole difficulty lies in solving equation system (4).
We can rearrange Equation (4), using $\boldsymbol{W}^{\top}\boldsymbol{W}=\boldsymbol{I}$, into:
$$\boldsymbol{Q}\boldsymbol{X} + \boldsymbol{X}\boldsymbol{Q} = -(\boldsymbol{Q}\boldsymbol{W}^{\top}\boldsymbol{G} + \boldsymbol{G}^{\top}\boldsymbol{W}\boldsymbol{Q}) = -2[\boldsymbol{Q}\boldsymbol{W}^{\top}\boldsymbol{G}]_{\text{sym}}$$
Given $\boldsymbol{Q}$, this is a linear equation system in $\boldsymbol{X}$, known as the "continuous Lyapunov equation", a special case of the "Sylvester equation". For CPU computation, SciPy provides the built-in function scipy.linalg.solve_continuous_lyapunov, which can be called directly.
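For illustration, here is a minimal sketch of solving this Lyapunov form with SciPy; the matrices `q` and `c` below are random stand-ins for $\boldsymbol{Q}$ and $[\boldsymbol{Q}\boldsymbol{W}^{\top}\boldsymbol{G}]_{\text{sym}}$, not quantities from the article.

```python
import numpy as np
from scipy.linalg import solve_continuous_lyapunov

np.random.seed(0)
m = 4
a = np.random.randn(m, m)
q = a @ a.T + m * np.eye(m)          # stand-in for Q (symmetric positive definite)
c = np.random.randn(m, m)
c = (c + c.T) / 2                    # stand-in for [Q W^T G]_sym (any symmetric matrix)

# solve_continuous_lyapunov(A, B) returns X with A X + X A^T = B,
# so the Lyapunov form Q X + X Q = -2 C is solved by:
x = solve_continuous_lyapunov(q, -2 * c)
print(np.abs(q @ x + x @ q + 2 * c).max())   # ~0
```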
Regarding the choice of initial value, a natural candidate is the square-matrix solution $-[\boldsymbol{W}^{\top}\boldsymbol{G}]_{\text{sym}}$, which makes the transition from square to non-square matrices seamless. We can also gauge the quality of the initial value $-[\boldsymbol{W}^{\top}\boldsymbol{G}]_{\text{sym}}$ through another equivalent form of Equation (4):
$$\boldsymbol{Q}\big(\boldsymbol{X} + [\boldsymbol{W}^{\top}\boldsymbol{G}]_{\text{sym}}\big) + \big(\boldsymbol{X} + [\boldsymbol{W}^{\top}\boldsymbol{G}]_{\text{sym}}\big)\boldsymbol{Q} = [\boldsymbol{W}^{\top}\boldsymbol{G}]_{\text{skew}}\,\boldsymbol{Q} - \boldsymbol{Q}\,[\boldsymbol{W}^{\top}\boldsymbol{G}]_{\text{skew}}$$
Thus, the accuracy of $-[\boldsymbol{W}^{\top}\boldsymbol{G}]_{\text{sym}}$ depends on how close $[\boldsymbol{W}^{\top}\boldsymbol{G}]_{\text{skew}}$ and $\boldsymbol{Q}$ are to commuting: the closer they are to commuting, the more accurate $-[\boldsymbol{W}^{\top}\boldsymbol{G}]_{\text{sym}}$ becomes. That said, the experiments below show that the iterative algorithm is not particularly sensitive to the initial value; even an all-zero initialization poses little problem.
Implementation Considerations#
As mentioned, SciPy has a built-in Lyapunov equation solver, so on CPU one can simply call it without worrying about the solution process. However, this only helps CPU-based computation: to the author's knowledge, neither PyTorch nor JAX provides an equivalent function, so for GPU computation one has to roll one's own solver.
There are two approaches to solving Equation (4) programmatically. The first follows the approach in "What Can the Matrix Sign Function mcsgn Compute?", using $\newcommand{\mcsgn}{\mathop{\text{mcsgn}}}\mcsgn$ (not $\msign$) to solve it:
$$\mcsgn\begin{pmatrix}-\boldsymbol{Q} & -[\boldsymbol{Q}\boldsymbol{W}^{\top}\boldsymbol{G}]_{\text{sym}} \\ \boldsymbol{0} & \boldsymbol{Q}\end{pmatrix} = \begin{pmatrix}-\boldsymbol{I} & \boldsymbol{X} \\ \boldsymbol{0} & \boldsymbol{I}\end{pmatrix}$$
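The following standalone sketch (an illustration added here, not the article's reference code; `mcsgn` is computed exactly via eigendecomposition, as in the test code below) shows how this block construction recovers $\boldsymbol{X}$.

```python
import numpy as np

def mcsgn(x):
    # exact matrix sign via eigendecomposition (the block matrix has real eigenvalues)
    s, v = np.linalg.eig(x)
    return (v @ np.diag(np.sign(s)) @ np.linalg.inv(v)).real

def solve_lyapunov_by_mcsgn(q, c):
    """Solve Q X + X Q = -2 C for symmetric positive-definite Q and symmetric C."""
    m = q.shape[0]
    block = np.block([[-q, -c], [np.zeros_like(q), q]])
    return mcsgn(block)[:m, m:]      # X sits in the top-right block

np.random.seed(0)
m = 4
a = np.random.randn(m, m)
q = a @ a.T + m * np.eye(m)
c = np.random.randn(m, m)
c = (c + c.T) / 2
x = solve_lyapunov_by_mcsgn(q, c)
print(np.abs(q @ x + x @ q + 2 * c).max())   # ~0
```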
The second approach is based on SVD, a method we already used when computing $\msign$ gradients in "Derivative of msign"; we restate it here in combination with Equation (4). By its definition, $\boldsymbol{Q}$ is symmetric positive definite and can therefore be eigendecomposed as $\boldsymbol{V}\boldsymbol{\Sigma}\boldsymbol{V}^{\top}$, where $\boldsymbol{V}$ is orthogonal and $\boldsymbol{\Sigma}=\mathop{\text{diag}}(\sigma_1,\cdots,\sigma_m)$ is diagonal. Substituting into Equation (4) yields:
$$\boldsymbol{\Sigma}\,(\boldsymbol{V}^{\top}\boldsymbol{X}\boldsymbol{V}) + (\boldsymbol{V}^{\top}\boldsymbol{X}\boldsymbol{V})\,\boldsymbol{\Sigma} = -2\,\boldsymbol{V}^{\top}[\boldsymbol{Q}\boldsymbol{W}^{\top}\boldsymbol{G}]_{\text{sym}}\boldsymbol{V}$$
The left side can be expressed as $(\boldsymbol{V}^{\top}\boldsymbol{X}\boldsymbol{V})\otimes \boldsymbol{S}$, where $\otimes$ denotes the Hadamard product and $\boldsymbol{S}_{i,j} = \sigma_i + \sigma_j$. Consequently, we can solve:
$$\boldsymbol{X} = -2\,\boldsymbol{V}\Big[\big(\boldsymbol{V}^{\top}[\boldsymbol{Q}\boldsymbol{W}^{\top}\boldsymbol{G}]_{\text{sym}}\boldsymbol{V}\big)\oslash\boldsymbol{S}\Big]\boldsymbol{V}^{\top}$$
where $\oslash$ denotes the Hadamard quotient. An interesting point here is that eigendecomposing $\boldsymbol{Q}$ is essentially equivalent to performing an SVD of $\boldsymbol{G} + \boldsymbol{W}\boldsymbol{X}$, and that SVD can also be used to compute $\msign(\boldsymbol{G} + \boldsymbol{W}\boldsymbol{X})$. Thus, a single SVD yields both $\msign$ and the solution of Equation (4).
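Packaged as a standalone helper, the eigendecomposition route looks roughly like the sketch below (illustrative only); since it needs nothing beyond `eigh` and elementwise operations, the same few lines port directly to GPU frameworks.

```python
import numpy as np

def solve_lyapunov_by_eigh(q, c):
    """Solve Q X + X Q = C for symmetric positive-definite Q via eigendecomposition."""
    sigma, v = np.linalg.eigh(q)            # Q = V diag(sigma) V^T
    s = sigma[:, None] + sigma[None, :]     # S_ij = sigma_i + sigma_j
    return v @ ((v.T @ c @ v) / s) @ v.T    # Hadamard quotient in the eigenbasis

np.random.seed(0)
m = 5
a = np.random.randn(m, m)
q = a @ a.T + m * np.eye(m)
c = np.random.randn(m, m)
c = (c + c.T) / 2
x = solve_lyapunov_by_eigh(q, -2 * c)       # right-hand side as in Equation (4)
print(np.abs(q @ x + x @ q + 2 * c).max())  # ~0
```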
The two approaches have rather different profiles. Approach 1 requires computing $\msign$ of an $n\times m$ matrix and then $\mcsgn$ of a $2m\times 2m$ matrix. Although both can be computed efficiently via Newton-Schulz iterations, the cost is non-trivial, and the iteration coefficients must be chosen to guarantee convergence and sufficient accuracy (the results from "Newton-Schulz Iteration for msign Operator (Part 2)" are recommended); otherwise neither $\mcsgn$ nor $\msign$ will converge, let alone $\boldsymbol{X}$.
Approach 2 requires SVD. Although SVD has higher computational complexity and often forces FP32 precision, in this problem, each iteration only requires one SVD to simultaneously compute $\msign$ and $\boldsymbol{X}$, so overall efficiency is not too poor. If the number of orthogonally constrained matrix parameters is small, SVD may be the simplest choice.
Related Approaches#
Prior to this article, @leloy proposed two heuristic solution methods for the original objective in his blog post "Heuristic Solutions for Steepest Descent on the Stiefel Manifold". Here, "heuristic" means that in most cases, it yields a reasonably good solution but cannot guarantee optimality. Let's examine these as well.
The first method can be described as purely geometric. First, define a projection operator:
$$\newcommand{\proj}{\mathop{\text{proj}}}\proj\nolimits_{\boldsymbol{W}}(\boldsymbol{M}) = \boldsymbol{M} - \boldsymbol{W}[\boldsymbol{W}^{\top}\boldsymbol{M}]_{\text{sym}}$$
It can be verified that $\boldsymbol{W}^{\top}\proj\nolimits_{\boldsymbol{W}}(\boldsymbol{M})$ is always skew-symmetric, i.e., $\proj\nolimits_{\boldsymbol{W}}(\boldsymbol{M})$ always lies in the tangent space. We can therefore regard it as the projection of an arbitrary matrix $\boldsymbol{M}$ onto the tangent space of $\boldsymbol{W}^{\top}\boldsymbol{W}=\boldsymbol{I}$.
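A quick numerical check of this claim (illustrative only):

```python
import numpy as np

def sym(x):
    return (x + x.T) / 2

np.random.seed(0)
n, m = 8, 3
w = np.linalg.qr(np.random.randn(n, m))[0]   # W with orthonormal columns
mat = np.random.randn(n, m)
p = mat - w @ sym(w.T @ mat)                 # proj_W(M)
print(np.abs(sym(w.T @ p)).max())            # ~1e-16: W^T proj_W(M) is skew-symmetric
```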
Starting from the gradient $\boldsymbol{G}$, $\proj\nolimits_{\boldsymbol{W}}(\boldsymbol{G})$ certainly lies in the tangent space. However, Muon's update should be a (column-)orthogonal matrix when full-rank, whereas $\proj\nolimits_{\boldsymbol{W}}(\boldsymbol{G})$ generally is not. We can therefore use $\msign$ to find the nearest orthogonal matrix, i.e., $\msign(\proj\nolimits_{\boldsymbol{W}}(\boldsymbol{G}))$. But after applying $\msign$, the result may no longer lie in the tangent space, so we project it back, take the nearest orthogonal matrix again, and iterate:
$$\boldsymbol{\Phi}\leftarrow \msign(\proj\nolimits_{\boldsymbol{W}}(\boldsymbol{\Phi})),\qquad \boldsymbol{\Phi}_{\text{init}} = \boldsymbol{G}$$
This is @leloy's first approach: alternately projecting onto the tangent space and the orthogonal space until convergence, which is quite intuitive. Under fairly random conditions it comes very close to the optimal solution, sometimes agreeing to four decimal places, which initially led the author to believe it was exact. After searching, however, cases with sufficiently large deviations from the optimum were found, confirming that the close agreement is coincidental and the method is not exact.
The second method can be described as a line search. Specifically, when $n > m$ we can extend $\boldsymbol{W}$ to a full $n\times n$ orthogonal matrix $[\boldsymbol{W},\overline{\boldsymbol{W}}]$ and decompose the desired $\boldsymbol{\Phi}$ into its $\boldsymbol{W}^{\top}\boldsymbol{\Phi}$ and $\overline{\boldsymbol{W}}{}^{\top}\boldsymbol{\Phi}$ components. @leloy then makes a greedy approximation: first find the optimal $\boldsymbol{W}^{\top}\boldsymbol{\Phi}$, then the optimal $\overline{\boldsymbol{W}}{}^{\top}\boldsymbol{\Phi}$, and introduce a line search between the two to improve accuracy.
This procedure yields a reasonably accurate approximation that is guaranteed to lie in the tangent space and satisfy the orthogonality requirement. The solution process involves computing a spectral norm, $\msign$, and a Cholesky decomposition; details can be found in @leloy's post. In addition, when $m=2$ the line search can in principle find the exact optimum, because a $2\times 2$ skew-symmetric matrix has only one degree of freedom, matching the single degree of freedom of the line search.
Experimental Tests#
Below, we test the above methods in NumPy, with the primary goal of verifying correctness, so we directly implement $\msign$ and $\mcsgn$ using singular value decomposition and eigendecomposition.
import numpy as np
import scipy as sp
import scipy.linalg  # makes sp.linalg available for the Lyapunov-solver alternative below
def mcsgn(x):
"""Precise computation of mcsgn via eigendecomposition
"""
s, v = np.linalg.eig(x)
return v @ np.diag(np.sign(s)) @ np.linalg.inv(v)
def msign(g):
"""Precise computation of msign via SVD
"""
u, s, vh = np.linalg.svd(g, full_matrices=False)
return u @ np.diag(np.sign(s)) @ vh
def sym(x):
"""Symmetrization
"""
return (x + x.T) * 0.5
def skew(x):
"""Skew-symmetrization
"""
return (x - x.T) * 0.5
def proj(g, w):
"""Projection onto orthogonal tangent space
"""
return g - w @ sym(w.T @ g)
def jianlin_by_mcsgn(g, w, steps=20):
"""Iterative construction via mcsgn (author's method)
"""
n, m = g.shape
x = -sym(w.T @ g)
for i in range(1, steps + 1):
phi = msign(z := g + w @ x)
print('step:', i, ', inner product:', (phi * g).sum(), ', tangent error:', np.abs(sym(w.T @ phi)).mean())
if i == steps:
return phi
q = z.T @ phi
x = mcsgn(np.block([[-q, -sym(q @ w.T @ g)], [np.zeros_like(q), q]]))[:m, m:]
# x = -2 * sp.linalg.solve_continuous_lyapunov(q, sym(q @ w.T @ g))
def jianlin_by_svd(g, w, steps=20):
"""Iterative construction via SVD (author's method)
"""
x = -sym(w.T @ g)
for i in range(1, steps + 1):
u, s, vh = np.linalg.svd(z := g + w @ x, full_matrices=False)
phi = (u * np.sign(s)) @ vh
print('step:', i, ', inner product:', (phi * g).sum(), ', tangent error:', np.abs(sym(w.T @ phi)).mean())
if i == steps:
return phi
x = -2 * vh.T @ (vh @ sym(z.T @ phi @ w.T @ g) @ vh.T / (s + s[:, None])) @ vh
def leloy_v1(g, w, steps=20):
"""Alternating projection onto tangent and orthogonal spaces
"""
phi = g
for i in range(1, steps + 1):
phi = msign(proj(phi, w))
print('step:', i, ', inner product:', (phi * g).sum(), ', tangent error:', np.abs(sym(w.T @ phi)).mean())
return phi
def leloy_v2(g, w, steps=20):
"""Component-wise greedy solution + line search (simplified by author)
"""
n, m = g.shape
taus = np.linspace(0, 1, steps + 2)[1:-1]
p_max, tau_opt, phi_opt = 0, 0, None
for tau in taus:
b = (b := skew(w.T @ g)) * tau / max(np.linalg.norm(b, ord=2), 1e-8)
r = np.linalg.cholesky(np.eye(m) - b.T @ b)
c = msign((np.eye(n) - w @ w.T) @ g @ r) @ r
phi = w @ b + c
print('tau:', tau, ', inner product:', p := (phi * g).sum())
if p > p_max:
p_max, tau_opt, phi_opt = p, tau, phi
print('best inner product:', p_max, ', tau:', tau_opt)
return phi_opt
# Test case 1
w = np.array([[ 0.69453734, -0.26590866, -0.44721806, 0.2753041 ],
[-0.11738148, -0.5588003 , -0.17580748, 0.3218624 ],
[-0.4515288 , -0.23489913, -0.26683152, -0.25739142],
[ 0.02392521, 0.02664689, 0.48423648, 0.6193399 ],
[ 0.45194831, -0.25206333, 0.27654836, -0.60242337],
[ 0.21197332, -0.09174792, 0.24521762, -0.08484317],
[-0.15496767, -0.26446804, -0.34942415, -0.01877318],
[-0.16181251, -0.6474956 , 0.45243263, -0.01776086]])
g = np.array([[-17.85745 , -10.758921 , -2.9583392 , 6.245008 ],
[-28.883093 , 19.772121 , 8.086545 , -21.564013 ],
[ -1.6274693 , -14.96859 , 3.4465332 , 3.1070817 ],
[ -7.8890743 , 1.5304767 , -8.949573 , 9.579629 ],
[ 2.246596 , 14.46572 , 12.8451 , -2.7370298 ],
[ -0.9496974 , 6.9879804 , 2.849277 , 1.1148484 ],
[ -8.115278 , -18.054405 , -0.19287404, 7.0389237 ],
[-15.062008 , -15.02901 , 2.9083247 , 21.706533 ]])
phi1 = jianlin_by_mcsgn(g, w, steps=100)
phi2 = jianlin_by_svd(g, w, steps=100)
phi3 = leloy_v1(g, w, steps=100)
phi4 = leloy_v2(g, w, steps=100)
assert np.allclose(phi1, phi2)
# Test case 2 (random)
w = np.linalg.qr(np.random.randn(100, 50))[0]
g = np.random.randn(100, 50)
phi1 = jianlin_by_mcsgn(g, w, steps=10)
phi2 = jianlin_by_svd(g, w, steps=10)
phi3 = leloy_v1(g, w, steps=10)
phi4 = leloy_v2(g, w, steps=10)
assert np.allclose(phi1, phi2)
For the first pair of $\boldsymbol{W},\boldsymbol{G}$ given in the code, the author's method attains an optimal $\tr(\boldsymbol{G}^{\top} \boldsymbol{\Phi})$ of roughly 90, with the $\mcsgn$ and SVD variants giving identical results. @leloy's first method reaches roughly 70 and the second roughly 80, both falling noticeably short of the optimum.
However, this first pair of $\boldsymbol{W},\boldsymbol{G}$ is an extreme example searched for specifically to highlight the differences. With relatively random inputs, the author's solution and @leloy's first method become very close and require fewer iterations (5–10 steps), while @leloy's second method then deviates more from the optimum. Readers can construct additional examples to test.
Extensions and Considerations#
This provisionally completes the solution of the original problem. Below we discuss a few details that may cause confusion.
First, for ease of exposition, the iterative solution above carries an implicit assumption: $\boldsymbol{G} + \boldsymbol{W}\boldsymbol{X}$ stays full-rank (rank $m$) throughout; otherwise the matrix $\boldsymbol{S}$ has zero entries and $\oslash\boldsymbol{S}$ becomes problematic. This difficulty is not fundamental, however, because Equation (1) is guaranteed to have a solution, so wherever a denominator is zero the corresponding numerator must be zero as well. It therefore suffices to replace the zero entries of $\boldsymbol{S}$ with a small positive number to obtain the correct result.
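In code this amounts to clamping the denominator before dividing, e.g. (an illustrative guard corresponding to the `(s + s[:, None])` division in `jianlin_by_svd` above):

```python
import numpy as np

s = np.array([3.0, 1.0, 0.0])        # singular values, one of them exactly zero
S = s[:, None] + s[None, :]          # S_ij = sigma_i + sigma_j
S = np.maximum(S, 1e-8)              # clamp zero entries before dividing by S
```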
From a numerical standpoint we rarely encounter singular values that are exactly zero, so this concern can largely be ignored and $\boldsymbol{G} + \boldsymbol{W}\boldsymbol{X}$ assumed full-rank by default. Under this default assumption, the retraction operation becomes simple, because:
$$(\boldsymbol{W}-\eta\boldsymbol{\Phi})^{\top}(\boldsymbol{W}-\eta\boldsymbol{\Phi}) = \boldsymbol{W}^{\top}\boldsymbol{W} - \eta(\boldsymbol{W}^{\top}\boldsymbol{\Phi}+\boldsymbol{\Phi}^{\top}\boldsymbol{W}) + \eta^2\,\boldsymbol{\Phi}^{\top}\boldsymbol{\Phi}$$
By the Stiefel manifold constraint, the first term on the right is $\boldsymbol{I}$; by the tangent-space condition, the second term is $\boldsymbol{0}$; and since $\msign$ of a full-rank matrix lies on the Stiefel manifold, the third term is $\eta^2 \boldsymbol{I}$, so the whole expression equals $(1+\eta^2)\boldsymbol{I}$. Dividing by $\sqrt{1+\eta^2}$ therefore achieves the retraction:
$$\boldsymbol{W}\leftarrow\frac{\boldsymbol{W}-\eta\boldsymbol{\Phi}}{\sqrt{1+\eta^2}}$$
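In code, one update step with this retraction could look like the following sketch (illustrative only; `solve_phi` stands for any of the solvers above, e.g. `jianlin_by_svd`):

```python
import numpy as np

def retract_step(w, g, eta, solve_phi):
    """One step of manifold gradient descent with the rescaling retraction."""
    phi = solve_phi(g, w)                        # e.g. jianlin_by_svd(g, w) from the listing above
    return (w - eta * phi) / np.sqrt(1 + eta**2)

# usage sketch: w = retract_step(w, g, eta=0.1, solve_phi=jianlin_by_svd)
```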
At this point, readers may notice a deeper issue: whether on the relatively simple orthogonal manifold or the more involved Stiefel manifold, what numerical precision should the computation use? "Orthogonal" is an exact quantitative constraint: $\boldsymbol{W}^{\top}\boldsymbol{W}=\boldsymbol{I}$ encodes $m(m+1)/2$ equality constraints. It is foreseeable that running the above iteration in low precision will gradually cause severe deviations from orthogonality, to say nothing of the errors in solving for $\boldsymbol{\Phi}$.
Therefore, the author believes that unless we periodically re-orthogonalize the parameters (i.e., $\boldsymbol{W}\leftarrow\msign(\boldsymbol{W})$) to pull them back onto the manifold, the solution process should be carried out in at least FP32. Since typically only a few parameters require orthogonal constraints, this is generally not an excessive cost.
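If parameters do drift off the manifold, the pull-back itself is cheap; an illustrative helper (using the SVD form of $\msign$, as in the test code above):

```python
import numpy as np

def orthogonalize(w):
    """Pull W back onto the Stiefel manifold (nearest matrix with orthonormal columns)."""
    u, s, vh = np.linalg.svd(w, full_matrices=False)
    return u @ vh

# e.g. call w = orthogonalize(w) every few hundred steps when training in low precision
```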
Summary#
This article extends the previous article's "Muon + orthogonal manifold" to the more general "Muon + Stiefel manifold," primarily discovering an iterative algorithm for solving the corresponding update direction.
Original Article: Su Jianlin. Manifold Gradient Descent: 3. Muon on Stiefel Manifold. Scientific Spaces.